#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: matthewroberts
"""

import os
import sys
import numpy as np
import glob
import datetime as dt

import matplotlib.pyplot as plt
import matplotlib.path as path
import matplotlib.cm as cm
from scipy import ndimage
from scipy.interpolate import griddata

from netCDF4 import Dataset

import warnings
# Suppress all Python warnings for the remainder of the run.
warnings.filterwarnings('ignore')

# Wall-clock time at which processing began, echoed as a banner.
start_time = dt.datetime.now()
print(f'\nProgram Started: {start_time}')
print('===================')

# --------------------------------------------------------
# --------------------------------------------------------


def refine_dom(lon, lat, rr):
    """Interpolate a coarse lon/lat grid onto a finer (fire-mesh) grid.

    The fire-refined coordinates are not available in wrfinput_d0x, so they
    are rebuilt here by linear interpolation in (row, column) index space.

    Parameters
    ----------
    lon, lat : 2-D arrays of identical shape (nx, ny)
        Coordinates of the coarse grid.
    rr : float
        Refinement ratio, 0 < rr < 1; intended to be 1/sr_x (sr_x is in
        namelist.input). It is also the spacing of the fine index grid.

    Returns
    -------
    (xxlon, xxlat) : tuple of 2-D arrays
        Fine-grid longitudes and latitudes. Along each axis the fine index
        runs from -4*rr up to (excluding) n + 4*rr in steps of rr; targets
        outside the convex hull of the coarse indices get griddata's default
        fill value (NaN).
    """
    print('Refining domain...')
    nx, ny = lon.shape

    # Integer (row, col) index of every coarse point as an (nx*ny, 2) array,
    # shared by both interpolations below.
    rows, cols = np.meshgrid(np.arange(0, nx, 1), np.arange(0, ny, 1),
                             indexing='ij')
    coarse_pts = np.array([rows.flatten(), cols.flatten()]).T

    # Fractional target indices, padded by four fine steps on every side.
    fine_rows, fine_cols = np.meshgrid(np.arange(-4*rr, nx + 4*rr, rr),
                                       np.arange(-4*rr, ny + 4*rr, rr),
                                       indexing='ij')
    targets = (fine_rows, fine_cols)

    def _interp(field):
        # Linear interpolation of one coarse field onto the fine targets.
        return griddata(coarse_pts, field.flatten(), targets, method='linear')

    return (_interp(lon), _interp(lat))
    
# --------------------------------------------------------
# --------------------------------------------------------
def mask_perim(perim, coords):
    """Return a boolean mask that is True for grid points inside a polygon.

    Parameters
    ----------
    perim : (N, 2) array-like
        Polygon vertices as (lon, lat) pairs. The path is built with
        closed=True, so matplotlib treats the final vertex as the closing
        point of the polygon.
    coords : pair (lon, lat)
        Arrays of identical shape holding the grid-point coordinates.

    Returns
    -------
    ndarray of bool with the shape of ``lat``; inside is True.
    """
    lon, lat = coords
    polygon = path.Path(perim, closed=True)
    # contains_points wants an (M, 2) array of (x, y) = (lon, lat) pairs.
    query = np.asarray((lon.ravel(), lat.ravel())).T
    return polygon.contains_points(query).reshape(lat.shape)
# --------------------------------------------------------
# --------------------------------------------------------
def read_json_perim(filename):
    """Load a GeoJSON file and return the features that carry a geometry.

    Parameters
    ----------
    filename : str or path-like
        Path to a GeoJSON FeatureCollection (the top-level object must have
        a 'features' list). The path is printed before reading.

    Returns
    -------
    list of dict
        Feature objects whose 'geometry' is present and not null, in file
        order.
    """
    import json

    print(filename)
    # RFC 7946 mandates UTF-8 for GeoJSON; do not rely on the locale default.
    with open(filename, encoding='utf-8') as f:
        features = json.load(f)['features']
    # A feature may have a null (or absent) geometry; those are skipped.
    return [feat for feat in features if feat.get('geometry') is not None]
# --------------------------------------------------------
# --------------------------------------------------------
def write_wrffile(var, data, fname, **kwargs):
    """Overwrite variable ``var`` in an existing NetCDF (WRF input) file.

    The whole variable is first set to 1.0; ``data`` is then written into
    the leading corner of the first record, i.e.
    ``variable[0, :nx, :ny] = data``.

    Parameters
    ----------
    var : str
        Name of the variable to overwrite; it is indexed as [0, :nx, :ny],
        so it must have at least three dimensions.
    data : 2-D array of shape (nx, ny)
        Values to store.
    fname : str
        Path to the file, opened in 'r+' (read/write) mode.
    **kwargs
        Accepted but unused.
    """
    rows, cols = data.shape
    with Dataset(fname, 'r+') as nc:
        target = nc.variables[var]
        # NOTE(review): any element not covered by `data` (later records, or
        # cells beyond the data extent) is left at 1.0; for TIGN_G that would
        # read as an ignition time - confirm this fill value is intended.
        target[:] = 1.0
        target[0, :rows, :cols] = data
# --------------------------------------------------------
# --------------------------------------------------------

def _feature_polygons(features):
    """Yield the ring list of every polygon contained in ``features``.

    A GeoJSON Polygon's coordinates are a list of linear rings (outer ring
    first, then holes); a MultiPolygon's coordinates are a list of such
    polygons. Every feature is visited, so files that mix Polygon and
    MultiPolygon features, or hold several MultiPolygon features, are read
    completely.

    Raises
    ------
    ValueError
        For any geometry type other than Polygon or MultiPolygon.
    """
    for feat in features:
        geom = feat['geometry']
        if geom['type'] == 'Polygon':
            yield geom['coordinates']
        elif geom['type'] == 'MultiPolygon':
            yield from geom['coordinates']
        else:
            raise ValueError(f"unsupported geometry type: {geom['type']!r}")


def _polygon_mask(rings, coords):
    """Return a bool mask of points inside the outer ring and outside all holes."""
    # Keep only (lon, lat); GeoJSON positions may carry a third (altitude) value.
    outer, *holes = [np.asarray(ring, dtype=float)[:, :2] for ring in rings]
    inside = mask_perim(outer, coords)
    for hole in holes:
        inside &= ~mask_perim(hole, coords)
    return inside


# Time interval between perimeters
perim_interval = 300  # seconds
# How many seconds into simulation to start fire
fire_start_time = 1800
# Refinement ratio handed to refine_dom (roughly 1/sr_x from namelist.input)
refine_ratio = 0.2506

filepath = '../radar_perims_json/'
perim_pattern = filepath + 'bear_*.geojson'
perimfiles = sorted(glob.glob(perim_pattern))
wrffile = './wrfinput_d03_forced'

if not perimfiles:
    # Without this guard the loop never runs and new_tign is undefined below.
    sys.exit(f'No perimeter files match {perim_pattern}')

with Dataset(wrffile, mode='r') as data:
    xlat = data.variables['XLAT'][:].squeeze()
    xlon = data.variables['XLONG'][:].squeeze()
    # Background ignition time of 999999 s everywhere, shaped like TIGN_G.
    tign_g = (data.variables['TIGN_G'][:].squeeze() * 0.) + 999999.
    fuel = data.variables['NFUEL_CAT'][:].squeeze()

fxlon, fxlat = refine_dom(xlon, xlat, refine_ratio)
fxnx, fxny = fxlon.shape
if tign_g.shape != fxlon.shape:
    sys.exit(f'Refined grid {fxlon.shape} does not match TIGN_G '
             f'{tign_g.shape}; check refine_ratio')

#%%
new_tign = tign_g
for p in perimfiles:
    features = read_json_perim(p)
    polygons = list(_feature_polygons(features))
    if not polygons:
        sys.exit(f'{p}: no polygon geometry found')
    print(f'{len(polygons)} polygon(s) from {len(features)} feature(s)')

    # Union of every polygon in this perimeter; summing the masks instead
    # would subtract the interval twice where features overlap.
    burned = np.zeros(fxlon.shape, dtype=bool)
    for kk, rings in enumerate(polygons):
        print(f'polygon {kk}: {len(rings[0])} vertices, '
              f'{len(rings) - 1} hole(s)')
        burned |= _polygon_mask(rings, (fxlon, fxlat))

    # Each perimeter containing a cell moves its ignition one interval
    # earlier. Assuming perimeters are nested (each contains the previous),
    # a cell first enclosed by perimeter k (0-based) ends up at
    # k * perim_interval after the normalisation below.
    new_tign = new_tign - burned * perim_interval

#%%
# sanity checks
print('old tign: ' + str(np.nanmin(new_tign)))
# makes first perim start at 0
new_tign = new_tign - np.min(new_tign)
print('new tign: ' + str(np.nanmin(new_tign)))
# NOTE(review): if some cell lies inside every perimeter, cells outside all
# perimeters end up at len(perimfiles) * perim_interval (+ fire_start_time),
# one interval after the last perimeter rather than a "never" value -
# confirm this is what the fire model expects.
# start the fire fire_start_time seconds into the simulation
new_tign = new_tign + fire_start_time
print('corrected tign: ' + str(np.nanmin(new_tign)))

# write data to file
write_wrffile('TIGN_G', new_tign, wrffile)

####################################
print('Plotting Domains...')
plt.figure(figsize=(10, 7))
# plot tign_g in minutes since simulation start
fplot = plt.pcolormesh(new_tign / 60.,
                       vmax=(np.nanmax(new_tign / 60.) - 1.),
                       alpha=.8)
plt.colorbar(fplot)
plt.savefig("../WRF/test/em_real/plot_tign.png")
plt.close('all')

####################################

# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
# accepts the same (name, lut) arguments on old and new versions.
cmap1 = plt.get_cmap('terrain', 54)

plt.figure(figsize=(10, 7))
# plot fuel categories
fuelplot = plt.pcolormesh(fuel, cmap=cmap1)
plt.colorbar(fuelplot)
plt.savefig("../WRF/test/em_real/plot_fuel.png")
plt.close('all')
####################################

print('Test plots saved...')

sys.exit()

